# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only

import logging
from typing import Dict, Optional

import torch

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
from mergekit.tokenizer.config import (
    ModelTokenEmbedding,
    TokenEmbeddingConfig,
    ZeroEmbedding,
)


class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]):
    """Build a new embedding matrix per model for a merged tokenizer.

    Given the merged tokenizer's vocabulary and per-model permutations that
    map new token ids to each model's original embedding rows, produce for
    every model an embedding matrix indexed by the new vocabulary. Tokens a
    model lacks are filled from a default embedding (averaged over the models
    that do have the token, copied from an explicitly configured source, or
    zeroed), and the output may be padded up to a multiple of
    ``pad_to_multiple_of`` rows.
    """

    # Supplies the input embedding tensors, keyed by model.
    gather_tensors: GatherTensors
    # Produces the merged tokenizer and per-model token-id permutations.
    tokenizer_task: BuildTokenizer
    # Optional per-token overrides controlling how embeddings are sourced.
    tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]]
    # If set, pad the output vocab dimension up to a multiple of this value.
    pad_to_multiple_of: Optional[int]
    # Preferred donor model for tokens it shares with the merged vocab.
    base_model: Optional[ModelReference]

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
    ) -> Dict[ModelReference, torch.Tensor]:
        """Return a permuted (and possibly padded) embedding matrix per model."""
        tokenizer = tokenizer_info.tokenizer
        permutations = tokenizer_info.permutations

        models = set(tensors.keys())
        if self.base_model:
            models.add(self.base_model)
        models = list(models)

        vocab = tokenizer.get_vocab()
        vocab_size = len(vocab)
        if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of:
            # Round the output row count up to the next multiple.
            vocab_size = (
                vocab_size // self.pad_to_multiple_of + 1
            ) * self.pad_to_multiple_of
        embed_size = tensors[models[0]].shape[1]
        assert all(
            t.shape[1] == embed_size for t in tensors.values()
        ), "Embedding sizes must match"

        # dtype/device are taken from an arbitrary model; inputs are assumed
        # to agree (only the embedding width is explicitly checked above).
        dtype = tensors[models[0]].dtype
        device = tensors[models[0]].device

        token_configs = dict(**(self.tokens or {}))
        tokens_to_average = self.assign_embedding_sources(
            permutations, models, vocab, token_configs
        )

        # Precompute fallback embeddings for tokens that at least one model
        # cannot copy directly from its own matrix.
        default_embeds = {}
        for token, token_id in vocab.items():
            if token in tokens_to_average:
                # Mean of the rows from every model that has the token.
                # assign_embedding_sources guarantees at least two models
                # contribute here, so count is never zero.
                embed = torch.zeros(embed_size, dtype=dtype, device=device)
                count = 0
                for model in models:
                    p = permutations[model]
                    if p[token_id] < 0:
                        continue
                    embed += tensors[model][p[token_id]]
                    count += 1
                embed /= count
            elif cfg := token_configs.get(token, None):
                cfg: TokenEmbeddingConfig
                embed = self.compute_default_embedding(
                    tokenizer_info, tensors, permutations, token, token_id, cfg
                )
            else:
                continue
            default_embeds[token] = embed

        result = {}
        for model in models:
            p = permutations[model]
            old_embed = tensors[model]
            new_embed = torch.zeros(
                (vocab_size, embed_size), dtype=dtype, device=device
            )
            for token, token_id in vocab.items():
                force = False
                if token in token_configs:
                    force = token_configs[token].force

                if p[token_id] >= 0 and not force:
                    # Model already has this token; copy its original row.
                    new_embed[token_id, :] = old_embed[p[token_id]]
                elif token in default_embeds:
                    new_embed[token_id, :] = default_embeds[token]
                else:
                    # Row remains zero; surface the gap loudly in the log.
                    logging.error(
                        f"No embedding for token {repr(token)} in model {model}!"
                    )

            if vocab_size > len(vocab):
                # Fill padding rows with the mean embedding,
                # as suggested by https://nlp.stanford.edu/~johnhew/vocab-expansion.html
                avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0)
                new_embed[len(vocab) :, :] = avg_embed
            result[model] = new_embed

        return result

    def assign_embedding_sources(
        self,
        permutations: Dict[ModelReference, Dict[int, int]],
        models: list[ModelReference],
        vocab: Dict[str, int],
        token_configs: Dict[str, TokenEmbeddingConfig],
    ) -> set:
        """Decide how each unconfigured token's embedding should be produced.

        Mutates ``token_configs`` in place, assigning an explicit source for
        tokens present in exactly one model (that model), tokens present in
        none (zeros), and — when a base model is set and has the token — the
        base model. Returns the set of tokens whose embeddings should instead
        be averaged across all models that have them.
        """
        permutation_list = [permutations[model] for model in models]

        tokens_to_average = set()
        for token, token_id in vocab.items():
            # Explicit user configuration always wins.
            if token in token_configs:
                continue

            has_token = [p[token_id] >= 0 for p in permutation_list]
            num_present = sum(int(x) for x in has_token)
            if num_present == 1:
                # Exactly one donor; copy its row verbatim.
                donor_model = models[has_token.index(True)]
                token_configs[token] = TokenEmbeddingConfig(source=donor_model)
                continue

            if num_present == 0:
                token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding())
                logging.warning(f"Token {repr(token)} not found in any model")
                continue

            # num_present >= 2 from here on: prefer the base model if it has
            # the token, otherwise average across the models that do.
            if self.base_model is not None:
                if permutations[self.base_model][token_id] >= 0:
                    token_configs[token] = TokenEmbeddingConfig(source=self.base_model)
                    continue

            tokens_to_average.add(token)
        return tokens_to_average

    def compute_default_embedding(
        self,
        tokenizer_info: TokenizerInfo,
        tensors: Dict[ModelReference, torch.Tensor],
        permutations: Dict[ModelReference, Dict[int, int]],
        token: str,
        token_id: int,
        cfg: TokenEmbeddingConfig,
    ) -> torch.Tensor:
        """Materialize the embedding for *token* from its configured source.

        Raises:
            AssertionError: if the configured source references a model or
                token that is not part of the merge.
            NotImplementedError: for an unrecognized source type.
        """
        if isinstance(cfg.source, ZeroEmbedding):
            # Previously this branch fell through and returned an unbound
            # local. Build the all-zero vector here, matching the width,
            # dtype, and device of the input embeddings.
            ref = next(iter(tensors.values()))
            embed = torch.zeros(ref.shape[1], dtype=ref.dtype, device=ref.device)
        elif isinstance(cfg.source, ModelTokenEmbedding):
            # Copy a specific row (by token or raw id) from a source model.
            model = cfg.source.model
            assert (
                model in permutations
            ), f"Model {model} referenced but not part of merge"
            src_token_id = cfg.source.token_id
            if src_token_id is None:
                src_token = cfg.source.token
                assert (
                    src_token in tokenizer_info.original_vocabs[model]
                ), f"Token {repr(src_token)} not found in model {model}"
                src_token_id = tokenizer_info.original_vocabs[model][src_token]
            assert (
                src_token_id >= 0 and src_token_id < tensors[model].shape[0]
            ), f"Token ID {src_token_id} out of range for model {model}"
            embed = tensors[model][src_token_id]
        elif isinstance(cfg.source, ModelReference):
            # Copy this same token's row from another model in the merge.
            model = cfg.source
            assert (
                model in permutations
            ), f"Model {model} referenced but not part of merge"
            p = permutations[model]
            assert p[token_id] >= 0, f"Token {repr(token)} not found in model {model}"
            embed = tensors[model][p[token_id]]
        else:
            raise NotImplementedError(cfg)
        return embed
